import warnings

warnings.filterwarnings("ignore")

import os
from functools import partial

import copy
import gym
import jax
import jax.numpy as jnp
import numpy as np
import tqdm
from absl import app, flags, logging
from flax.training import checkpoints
from ml_collections import config_flags

from configs.ensemble_config import add_redq_config
from soar.agents import agents
from soar.common.evaluation import evaluate_with_trajectories
from soar.common.wandb import WandBLogger
from soar.data.replay_buffer import ReplayBuffer, ReplayBufferMC
from soar.envs.adroit_binary_dataset import get_hand_dataset_with_mc_calculation
from soar.envs.d4rl_dataset import (
    get_d4rl_dataset,
    get_d4rl_dataset_with_mc_calculation,
)
from soar.envs.env_common import get_env_type, make_gym_env
from soar.utils.timer_utils import Timer
from soar.utils.train_utils import concatenate_batches, subsample_batch
from d4rl import kitchen

FLAGS = flags.FLAGS

# env
flags.DEFINE_string("env", "antmaze-large-diverse-v2", "Environemnt to use")
flags.DEFINE_float("reward_scale", 1.0, "Reward scale.")
flags.DEFINE_float("reward_bias", 0.0, "Reward bias.")
flags.DEFINE_float(
    "clip_action",
    0.99999,
    "Clip actions to be between [-n, n]. This is needed for tanh policies.",
)

# training
flags.DEFINE_integer("num_offline_steps", 1_000_000, "Number of offline epochs.")
flags.DEFINE_integer("num_online_steps", 200_000, "Number of online epochs.")
flags.DEFINE_integer(
    "kl_divergence_interval",
    5_000,
    "Interval to compute KL divergence between the current agent and the pretrained agent",
)
flags.DEFINE_float(
    "offline_data_ratio",
    0.0,
    "How much offline data to retain in each online batch update",
)
flags.DEFINE_float(
    "warmup_offline_data_ratio",
    0.0,
    "How much offline data to retain in each online batch update",
)
flags.DEFINE_string(
    "online_sampling_method",
    "mixed",
    """Method of sampling data during online update: mixed or append.
    `mixed` samples from a mix of offline and online data according to offline_data_ratio.
    `append` adds offline data to replay buffer and samples from it.""",
)
flags.DEFINE_bool(
    "online_use_cql_loss",
    True,
    """When agent is CQL/CalQL, whether to use CQL loss for the online phase (use SAC loss if False)""",
)
flags.DEFINE_integer(
    "warmup_steps", 0, "number of warmup steps before performing online updates"
)

# agent
flags.DEFINE_string("agent", "calql", "what RL agent to use")
flags.DEFINE_integer("utd", 1, "update-to-data ratio of the critic")
flags.DEFINE_integer("batch_size", 256, "batch size for training")
flags.DEFINE_integer("replay_buffer_capacity", int(2e6), "Replay buffer capacity")
flags.DEFINE_bool("use_redq", True, "Use an ensemble of Q-functions for the agent")

# experiment house keeping
flags.DEFINE_integer("seed", 10, "Random seed.")
flags.DEFINE_string(
    "save_dir",
    os.path.expanduser("~/soar_log"),
    "Directory to save the logs and checkpoints",
)
flags.DEFINE_string("resume_path", "", "Path to resume from")
flags.DEFINE_string("finetuned_agent_resume_path", "", "Path to resume from")
flags.DEFINE_integer("log_interval", 5_000, "Log every n steps")
flags.DEFINE_integer("eval_interval", 5_000, "Evaluate every n steps")
flags.DEFINE_integer("save_interval", 100_000, "Save every n steps.")
flags.DEFINE_integer(
    "n_eval_trajs", 20, "Number of trajectories to use for each evaluation."
)
flags.DEFINE_bool("deterministic_eval", True, "Whether to use deterministic evaluation")

# wandb
flags.DEFINE_string("exp_name", "", "Experiment name for wandb logging")
flags.DEFINE_string("project", None, "Wandb project folder")
flags.DEFINE_string("group", None, "Wandb group of the experiment")
flags.DEFINE_bool("debug", False, "If true, no logging to wandb")
flags.DEFINE_bool(
    "resume_actor_only", False, "Whether to resume only the actor network"
)
flags.DEFINE_bool(
    "train_critic_during_warmup",
    False,
    "Whether to train the critic during warmup. If False, no network updates are performed during warmup.",
)

# offline data ratio annealing
flags.DEFINE_bool(
    "use_offline_data_ratio_annealing", False, "Whether to use offline data annealing"
)
flags.DEFINE_float(
    "min_offline_data_ratio",
    0.0,
    "Minimum offline data ratio during warmup if using offline data annealing",
)
flags.DEFINE_float(
    "max_offline_data_ratio",
    1.0,
    "Maximum offline data ratio during warmup if using offline data annealing",
)
flags.DEFINE_float(
    "anneal_interval",
    20000,
    "Interval (in steps) to anneal the offline data ratio",
)
flags.DEFINE_string(
    "offline_data_ratio_anneal_method",
    "linear",
    "Method to anneal offline data ratio. Options are linear, exponential, and cosine",
)
flags.DEFINE_float(
    "offline_data_ratio_exponential_decay_rate",
    5.0,
    "Decay rate for exponential offline data ratio annealing",
)
# conservative penalty annealing
flags.DEFINE_bool(
    "conservative_penalty_annealing",
    False,
    "Whether to linearly anneal the conservative penalty (CQL alpha)",
)
flags.DEFINE_float(
    "min_cql_alpha",
    0.0,
    "Minimum CQL alpha (conservative penalty) during online training if using conservative penalty annealing",
)
flags.DEFINE_float(
    "cql_alpha_anneal_interval",
    20000,
    "Interval (in steps) to anneal the CQL alpha (conservative penalty)",
)
flags.DEFINE_string(
    "cql_alpha_anneal_method",
    "linear",
    "Method to anneal CQL alpha (conservative penalty). Options are linear, exponential, and cosine",
)
flags.DEFINE_float(
    "cql_alpha_exponential_decay_rate",
    5.0,
    "Decay rate for exponential offline data ratio annealing",
)
flags.DEFINE_float(
    "cql_alpha_power",
    3.0,
    "Power for power annealing of CQL alpha (conservative penalty)",
)
flags.DEFINE_bool(
    "save_final_checkpoint",
    False,
    "Whether to save the final checkpoint at the end of training",
)
config_flags.DEFINE_config_file(
    "config",
    None,
    "File path to the training hyperparameter configuration.",
    lock_config=False,
)


def main(_):
    """
    house keeping
    """
    assert FLAGS.online_sampling_method in [
        "mixed",
        "append",
    ], "incorrect online sampling method"

    if FLAGS.use_redq:
        FLAGS.config.agent_kwargs = add_redq_config(FLAGS.config.agent_kwargs)

    min_steps_to_update = FLAGS.batch_size * (1 - FLAGS.offline_data_ratio)
    if FLAGS.agent == "calql":
        min_steps_to_update = max(
            min_steps_to_update, gym.make(FLAGS.env)._max_episode_steps
        )

    is_resume = FLAGS.resume_path != ""
    is_finetuned_agent_resume = FLAGS.finetuned_agent_resume_path != ""

    """
    wandb and logging
    """
    wandb_config = WandBLogger.get_default_config()
    wandb_config.update(
        {
            "name": FLAGS.exp_name,
            "project": "soar" or FLAGS.project,
            "group": "soar" or FLAGS.group,
            "exp_descriptor": f"{FLAGS.exp_name}_{FLAGS.env}_{FLAGS.agent}_seed{FLAGS.seed}",
            "agent": FLAGS.agent,
            "env": FLAGS.env,
            "resume": "resume" if is_resume else "from_scratch",
        }
    )
    wandb_logger = WandBLogger(
        wandb_config=wandb_config,
        variant=FLAGS.config.to_dict(),
        random_str_in_identifier=True,
        disable_online_logging=FLAGS.debug,
    )

    save_dir = os.path.join(
        FLAGS.save_dir,
        wandb_logger.config.project,
        f"{wandb_logger.config.exp_descriptor}_{wandb_logger.config.unique_identifier}",
    )

    """
    env
    """
    # do not clip adroit actions online following CalQL repo
    # https://github.com/nakamotoo/Cal-QL
    env_type = get_env_type(FLAGS.env)
    finetune_env = make_gym_env(
        env_name=FLAGS.env,
        reward_scale=FLAGS.reward_scale,
        reward_bias=FLAGS.reward_bias,
        scale_and_clip_action=env_type in ("antmaze", "kitchen", "locomotion"),
        action_clip_lim=FLAGS.clip_action,
        seed=FLAGS.seed,
    )
    eval_env = make_gym_env(
        env_name=FLAGS.env,
        scale_and_clip_action=env_type in ("antmaze", "kitchen", "locomotion"),
        action_clip_lim=FLAGS.clip_action,
        seed=FLAGS.seed + 1000,
    )

    """
    load dataset
    """
    if env_type == "adroit-binary":
        dataset = get_hand_dataset_with_mc_calculation(
            FLAGS.env,
            gamma=FLAGS.config.agent_kwargs.discount,
            reward_scale=FLAGS.reward_scale,
            reward_bias=FLAGS.reward_bias,
            clip_action=FLAGS.clip_action,
        )
    else:
        if FLAGS.agent == "calql":
            # need dataset with mc return
            dataset = get_d4rl_dataset_with_mc_calculation(
                FLAGS.env,
                reward_scale=FLAGS.reward_scale,
                reward_bias=FLAGS.reward_bias,
                clip_action=FLAGS.clip_action,
                gamma=FLAGS.config.agent_kwargs.discount,
            )
        else:
            dataset = get_d4rl_dataset(
                FLAGS.env,
                reward_scale=FLAGS.reward_scale,
                reward_bias=FLAGS.reward_bias,
                clip_action=FLAGS.clip_action,
            )

    """
    replay buffer
    """
    replay_buffer_type = ReplayBufferMC if FLAGS.agent == "calql" else ReplayBuffer
    replay_buffer = replay_buffer_type(
        finetune_env.observation_space,
        finetune_env.action_space,
        capacity=FLAGS.replay_buffer_capacity,
        seed=FLAGS.seed,
        discount=FLAGS.config.agent_kwargs.discount if FLAGS.agent == "calql" else None,
    )

    """
    Initialize agent
    """
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, construct_rng = jax.random.split(rng)
    example_batch = subsample_batch(dataset, FLAGS.batch_size)
    agent = agents[FLAGS.agent].create(
        rng=construct_rng,
        observations=example_batch["observations"],
        actions=example_batch["actions"],
        encoder_def=None,
        **FLAGS.config.agent_kwargs,
    )

    if is_resume:
        # get prefix from environment variable
        RESUME_PATH_PREFIX = os.getenv("RESUME_PATH_PREFIX", "/home/ubuntu/soar_log/")
        FLAGS.resume_path = os.path.join(RESUME_PATH_PREFIX, FLAGS.resume_path)
        assert os.path.exists(FLAGS.resume_path), "resume path does not exist"

        logging.info("Resuming the entire agent")
        agent = checkpoints.restore_checkpoint(FLAGS.resume_path, target=agent)

        pretrained_agent = copy.deepcopy(agent)

        if FLAGS.resume_actor_only:
            # load only actor network
            logging.info("Resuming only the actor network")
            agent.initialize_critic(
                rng=jax.random.PRNGKey(FLAGS.seed),
                observations=example_batch["observations"],
                actions=example_batch["actions"],
            )

    if is_finetuned_agent_resume:
        # get prefix from environment variable
        RESUME_PATH_PREFIX = os.getenv("RESUME_PATH_PREFIX", "/home/ubuntu/soar_log/")
        FLAGS.finetuned_agent_resume_path = os.path.join(
            RESUME_PATH_PREFIX, FLAGS.finetuned_agent_resume_path
        )
        assert os.path.exists(
            FLAGS.finetuned_agent_resume_path
        ), "finetuned agent resume path does not exist"

        finetuned_agent = copy.deepcopy(agent)
        logging.info("Resuming the finetuned agent")
        finetuned_agent = checkpoints.restore_checkpoint(
            FLAGS.finetuned_agent_resume_path, target=finetuned_agent
        )

    """
    eval function
    """

    def evaluate_and_log_results(
        eval_env,
        policy_fn,
        eval_func,
        step_number,
        wandb_logger,
        n_eval_trajs=FLAGS.n_eval_trajs,
    ):
        stats, trajs = eval_func(
            policy_fn,
            eval_env,
            n_eval_trajs,
        )

        eval_info = {
            "average_return": np.mean([np.sum(t["rewards"]) for t in trajs]),
            "average_traj_length": np.mean([len(t["rewards"]) for t in trajs]),
        }
        if env_type == "adroit-binary":
            # adroit
            eval_info["success_rate"] = np.mean(
                [any(d["goal_achieved"] for d in t["infos"]) for t in trajs]
            )
        elif env_type == "kitchen":
            # kitchen
            eval_info["num_stages_solved"] = np.mean([t["rewards"][-1] for t in trajs])
            eval_info["success_rate"] = np.mean([t["rewards"][-1] for t in trajs]) / 4
        else:
            # d4rl antmaze, locomotion
            eval_info["success_rate"] = eval_info["average_normalized_return"] = (
                np.mean(
                    [eval_env.get_normalized_score(np.sum(t["rewards"])) for t in trajs]
                )
            )

        wandb_logger.log({"evaluation": eval_info}, step=step_number)

    """
    training loop
    """
    timer = Timer()
    step = int(agent.state.step)  # 0 for new agents, or load from pre-trained
    is_online_stage = False
    observation, info = finetune_env.reset()
    done = False  # env done signal

    is_cql_alpha_annealed = False
    cql_alpha = 0.0

    if FLAGS.agent in ("cql", "calql"):
        cql_alpha = FLAGS.config.agent_kwargs.get("online_cql_alpha", 0.0)

    for _ in tqdm.tqdm(range(step, FLAGS.num_offline_steps + FLAGS.num_online_steps)):
        """
        Switch from offline to online
        """
        if not is_online_stage and step >= FLAGS.num_offline_steps:
            logging.info("Switching to online training")
            is_online_stage = True

            # upload offline data to online buffer
            if FLAGS.online_sampling_method == "append":
                offline_dataset_size = dataset["actions"].shape[0]
                dataset_items = dataset.items()
                for j in range(offline_dataset_size):
                    transition = {k: v[j] for k, v in dataset_items}
                    replay_buffer.insert(transition)

            # option for CQL and CalQL to change the online alpha, and whether to use CQL regularizer
            if FLAGS.agent in ("cql", "calql"):
                online_agent_configs = {
                    "cql_alpha": FLAGS.config.agent_kwargs.get(
                        "online_cql_alpha", None
                    ),
                    "use_cql_loss": FLAGS.online_use_cql_loss,
                }
                agent.update_config(online_agent_configs)

        timer.tick("total")

        """
        Env Step
        """
        with timer.context("env step"):
            if is_online_stage:
                rng, action_rng = jax.random.split(rng)
                action = agent.sample_actions(observation, seed=action_rng)
                next_observation, reward, done, truncated, info = finetune_env.step(
                    action
                )

                transition = dict(
                    observations=observation,
                    next_observations=next_observation,
                    actions=action,
                    rewards=reward,
                    masks=1.0 - done,
                    dones=1.0 if (done or truncated) else 0,
                )
                replay_buffer.insert(transition)

                observation = next_observation
                if done or truncated:
                    observation, info = finetune_env.reset()
                    done = False

        """
        Updates
        """
        with timer.context("update"):
            # offline updates
            if not is_online_stage:
                batch = subsample_batch(dataset, FLAGS.batch_size)
                agent, update_info = agent.update(
                    batch,
                )

            # online updates
            else:
                # anneal warmup offline data ratio based on online steps

                num_online_steps = step - FLAGS.num_offline_steps - FLAGS.warmup_steps
                if num_online_steps > 0 and FLAGS.use_offline_data_ratio_annealing:
                    if num_online_steps <= FLAGS.anneal_interval:
                        if FLAGS.offline_data_ratio_anneal_method == "linear":
                            FLAGS.offline_data_ratio = FLAGS.max_offline_data_ratio - (
                                (
                                    FLAGS.max_offline_data_ratio
                                    - FLAGS.min_offline_data_ratio
                                )
                                * (num_online_steps / FLAGS.anneal_interval)
                            )
                        elif FLAGS.offline_data_ratio_anneal_method == "exponential":
                            decay_rate = FLAGS.offline_data_ratio_exponential_decay_rate
                            FLAGS.offline_data_ratio = max(
                                FLAGS.min_offline_data_ratio,
                                FLAGS.max_offline_data_ratio
                                * np.exp(
                                    -decay_rate
                                    * (num_online_steps / FLAGS.anneal_interval)
                                ),
                            )
                        elif FLAGS.offline_data_ratio_anneal_method == "cosine":
                            FLAGS.offline_data_ratio = (
                                FLAGS.max_offline_data_ratio
                                * 0.5
                                * (
                                    1
                                    + np.cos(
                                        np.pi
                                        * (num_online_steps / FLAGS.anneal_interval)
                                    )
                                )
                            )

                        else:
                            raise RuntimeError(
                                "Incorrect offline data ratio annealing method"
                            )
                    wandb_logger.log(
                        {"offline_data_ratio": FLAGS.offline_data_ratio},
                        step=step,
                    )

                if num_online_steps > 0 and FLAGS.conservative_penalty_annealing:
                    if num_online_steps <= FLAGS.cql_alpha_anneal_interval:
                        online_cql_alpha = FLAGS.config.agent_kwargs.get(
                            "online_cql_alpha", 0.0
                        )
                        if FLAGS.cql_alpha_anneal_method == "linear":
                            cql_alpha = online_cql_alpha - (
                                (online_cql_alpha - FLAGS.min_cql_alpha)
                                * (num_online_steps / FLAGS.cql_alpha_anneal_interval)
                            )
                        elif FLAGS.cql_alpha_anneal_method == "exponential":
                            decay_rate = FLAGS.cql_alpha_exponential_decay_rate
                            cql_alpha = max(
                                FLAGS.min_cql_alpha,
                                online_cql_alpha
                                * np.exp(
                                    -decay_rate
                                    * (
                                        num_online_steps
                                        / FLAGS.cql_alpha_anneal_interval
                                    )
                                ),
                            )
                        elif FLAGS.cql_alpha_anneal_method == "cosine":
                            cql_alpha = (
                                online_cql_alpha
                                * 0.5
                                * (
                                    1
                                    + np.cos(
                                        np.pi
                                        * (
                                            num_online_steps
                                            / FLAGS.cql_alpha_anneal_interval
                                        )
                                    )
                                )
                            )
                        elif FLAGS.cql_alpha_anneal_method == "power":
                            power = FLAGS.cql_alpha_power
                            t = num_online_steps
                            T = FLAGS.cql_alpha_anneal_interval
                            normalized_t = t / T
                            cql_alpha = max(
                                FLAGS.min_cql_alpha,
                                online_cql_alpha * (1 - normalized_t**power),
                            )
                        else:
                            raise RuntimeError("Incorrect CQL alpha annealing method")
                    else:
                        cql_alpha = FLAGS.min_cql_alpha

                    wandb_logger.log({"cql_alpha": cql_alpha}, step=step)

                if (
                    FLAGS.agent in ("cql", "calql")
                    and not is_cql_alpha_annealed
                    and cql_alpha <= 0
                ):
                    agent.update_config({"use_cql_loss": False})
                    is_cql_alpha_annealed = True

                if step - FLAGS.num_offline_steps <= min_steps_to_update:
                    pass
                elif step - FLAGS.num_offline_steps <= FLAGS.warmup_steps:
                    if FLAGS.train_critic_during_warmup:
                        # train critic during warmup
                        batch_size_offline = int(
                            FLAGS.batch_size * FLAGS.warmup_offline_data_ratio
                        )
                        batch_size_online = FLAGS.batch_size - batch_size_offline
                        online_batch = replay_buffer.sample(batch_size_online)
                        offline_batch = subsample_batch(dataset, batch_size_offline)
                        # update with the combined batch
                        batch = concatenate_batches([online_batch, offline_batch])
                        if FLAGS.utd > 1:
                            agent, update_info = agent.update_high_utd_critic(
                                batch,
                                utd_ratio=FLAGS.utd,
                            )
                        else:
                            agent, update_info = agent.update_critic(
                                batch,
                            )
                    else:
                        pass

                else:
                    # do online updates, gather batch
                    if FLAGS.online_sampling_method == "mixed":
                        # batch from a mixing ratio of offline and online data
                        batch_size_offline = int(
                            FLAGS.batch_size * FLAGS.offline_data_ratio
                        )
                        batch_size_online = FLAGS.batch_size - batch_size_offline
                        online_batch = replay_buffer.sample(batch_size_online)
                        offline_batch = subsample_batch(dataset, batch_size_offline)
                        # update with the combined batch
                        batch = concatenate_batches([online_batch, offline_batch])
                    elif FLAGS.online_sampling_method == "append":
                        # batch from online replay buffer, with is initialized with offline data
                        batch = replay_buffer.sample(FLAGS.batch_size)
                    else:
                        raise RuntimeError("Incorrect online sampling method")

                    if (
                        FLAGS.conservative_penalty_annealing
                        and not is_cql_alpha_annealed
                    ):
                        # pass down the annealed cql alpha
                        batch["cql_alpha"] = cql_alpha

                    else:
                        pass

                    # update
                    if FLAGS.utd > 1:
                        agent, update_info = agent.update_high_utd(
                            batch,
                            utd_ratio=FLAGS.utd,
                        )
                    else:
                        agent, update_info = agent.update(
                            batch,
                        )

        """
        Advance Step
        """
        step += 1

        """
        Evals
        """
        eval_steps = (
            FLAGS.num_offline_steps,  # finish offline training
            FLAGS.num_offline_steps + 1,  # start of online training
            FLAGS.num_offline_steps + FLAGS.num_online_steps,  # end of online training
        )
        if step % FLAGS.eval_interval == 0 or step in eval_steps:
            logging.info("Evaluating...")
            with timer.context("evaluation"):
                policy_fn = partial(
                    agent.sample_actions, argmax=FLAGS.deterministic_eval
                )
                eval_func = partial(
                    evaluate_with_trajectories, clip_action=FLAGS.clip_action
                )

                evaluate_and_log_results(
                    eval_env=eval_env,
                    policy_fn=policy_fn,
                    eval_func=eval_func,
                    step_number=step,
                    wandb_logger=wandb_logger,
                )

        """
        Save Checkpoint
        """
        if step == FLAGS.num_offline_steps or (
            step >= FLAGS.num_offline_steps + FLAGS.num_online_steps
            and FLAGS.save_final_checkpoint
        ):
            logging.info("Saving checkpoint...")
            checkpoint_path = checkpoints.save_checkpoint(
                save_dir, agent, step=step, keep=30
            )
            logging.info("Saved checkpoint to %s", checkpoint_path)

        timer.tock("total")

        """
        Logging
        """
        if step % FLAGS.log_interval == 0:
            # check if update_info is available (False during warmup)
            if "update_info" in locals():
                update_info = jax.device_get(update_info)
                wandb_logger.log({"training": update_info}, step=step)

            wandb_logger.log({"timer": timer.get_average_times()}, step=step)


if __name__ == "__main__":
    app.run(main)
